{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aco import ACO\n",
    "import numpy as np\n",
    "import logging\n",
    "from gen_inst import OPInstance, load_dataset\n",
    "import torch\n",
    "\n",
    "N_ANTS = 30\n",
    "N_ITERATIONS = [1, 10, 30, 50, 80, 100]\n",
    "\n",
    "def heuristics_reevo(prize: np.ndarray, distance: np.ndarray, maxlen: float) -> np.ndarray:\n",
    "    \"\"\"Score every edge by the cubed prize-to-distance ratio of its destination.\n",
    "\n",
    "    Entry ``[i, j]`` is ``(prize[j] / distance[i, j]) ** 3`` when\n",
    "    ``0 < distance[i, j] <= maxlen`` and ``0`` otherwise.\n",
    "\n",
    "    Args:\n",
    "        prize: 1-D array of length ``n`` with one prize per node.\n",
    "        distance: ``(n, n)`` array of pairwise distances.\n",
    "        maxlen: Longest edge that may receive a non-zero score.\n",
    "\n",
    "    Returns:\n",
    "        An ``(n, n)`` float64 array of finite edge scores (for finite inputs).\n",
    "    \"\"\"\n",
    "    prize = np.asarray(prize)\n",
    "    distance = np.asarray(distance)\n",
    "    n = prize.shape[0]\n",
    "    heuristics = np.zeros((n, n))\n",
    "\n",
    "    # Only positive distances within the length limit are scored. Zero-length\n",
    "    # edges (such as a zero diagonal) are left at 0: dividing by them would\n",
    "    # produce inf/nan scores and a RuntimeWarning.\n",
    "    usable = (distance > 0) & (distance <= maxlen)\n",
    "\n",
    "    # Broadcast the prizes across rows so that column j carries prize[j],\n",
    "    # then divide on the usable edges only.\n",
    "    prize_grid = np.broadcast_to(prize, distance.shape)\n",
    "    heuristics[usable] = np.power(prize_grid[usable] / distance[usable], 3)\n",
    "\n",
    "    return heuristics\n",
    "\n",
    "def solve(inst: OPInstance):\n",
    "    \"\"\"Run ACO on one instance and report the objective at each checkpoint.\n",
    "\n",
    "    The heuristic matrix comes from ``heuristics_reevo``, is shifted by 1e-9\n",
    "    and then floored at 1e-9 before being handed to ``ACO`` as a torch tensor.\n",
    "\n",
    "    Args:\n",
    "        inst: Instance providing ``prize``, ``distance``, ``maxlen`` and ``n``.\n",
    "\n",
    "    Returns:\n",
    "        A list with one scalar per entry of ``N_ITERATIONS``: the objective\n",
    "        returned by the ``aco.run`` call that reaches that iteration count.\n",
    "    \"\"\"\n",
    "    raw_scores = heuristics_reevo(np.array(inst.prize), np.array(inst.distance), inst.maxlen)\n",
    "    scores = raw_scores + 1e-9\n",
    "    assert tuple(scores.shape) == (inst.n, inst.n)\n",
    "    # Floor every entry that is still below the offset (e.g. negative scores).\n",
    "    scores[scores < 1e-9] = 1e-9\n",
    "    colony = ACO(inst.prize, inst.distance, inst.maxlen, torch.from_numpy(scores), N_ANTS)\n",
    "\n",
    "    # Each call only runs the iterations still missing to reach the next\n",
    "    # checkpoint. NOTE(review): this presumably relies on ACO keeping its\n",
    "    # state between run() calls -- confirm against aco.py.\n",
    "    checkpoints = []\n",
    "    completed = 0\n",
    "    for target in N_ITERATIONS:\n",
    "        best, _ = colony.run(target - completed)\n",
    "        checkpoints.append(best.item())\n",
    "        completed = target\n",
    "    return checkpoints\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[*] Average for 50, 1 iterations: 12.462335858585858\n",
      "[*] Average for 50, 10 iterations: 14.484812184343431\n",
      "[*] Average for 50, 30 iterations: 15.15307922979798\n",
      "[*] Average for 50, 50 iterations: 15.31330018939394\n",
      "[*] Average for 50, 80 iterations: 15.379081439393937\n",
      "[*] Average for 50, 100 iterations: 15.39080018939394\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for problem_size in [50]:\n",
    "    # Validation set for this problem size, resolved relative to the working directory.\n",
    "    dataset_path = f\"./dataset/val{problem_size}_dataset.npz\"\n",
    "    dataset = load_dataset(dataset_path)\n",
    "    n_instances = dataset[0].n\n",
    "    logging.info(f\"[*] Evaluating {dataset_path}\")\n",
    "\n",
    "    # One row per instance; solve() yields one value per entry of N_ITERATIONS.\n",
    "    objs = []\n",
    "    for instance in dataset:\n",
    "        objs.append(solve(instance))\n",
    "\n",
    "    # Average each checkpoint column over all instances.\n",
    "    mean_obj = np.asarray(objs).mean(axis=0)\n",
    "    for i, obj in enumerate(mean_obj):\n",
    "        print(f\"[*] Average for {problem_size}, {N_ITERATIONS[i]} iterations: {obj}\")\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
